{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMaker endpoint\n",
    "\n",
    "To deploy the model you previously trained, you need to create a SageMaker endpoint. This is a hosted prediction service that you can use to perform inference.\n",
    "\n",
    "## Finding the model\n",
    "\n",
    "This notebook uses a stored model if one exists. If you recently ran a training example that uses the `%store` magic, the model location is restored in the next cell.\n",
    "\n",
    "Otherwise, you can pass the URI to the model file (a .tar.gz file) in the `model_data` variable.\n",
    "\n",
    "You can find your model files through the [SageMaker console](https://console.aws.amazon.com/sagemaker/home) by choosing **Training > Training jobs** in the left navigation pane. Find your recent training job, choose it, and then look for the `s3://` link in the **Output** pane. Then uncomment the `model_data` line in the next cell and set it to the model's URI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve a saved model from a previous notebook run's stored variable\n",
    "%store -r model_data\n",
    "\n",
    "# If no model was found, set it manually here.\n",
    "# model_data = 's3://sagemaker-us-west-2-XXX/pytorch-smdataparallel-mnist-2020-10-16-17-15-16-419/output/model.tar.gz'\n",
    "\n",
    "print(\"Using this model: {}\".format(model_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a model object\n",
    "\n",
    "You define the model object by using the SageMaker SDK's `PyTorchModel` class, passing in the location of the model artifact (`model_data`) and the `entry_point` script. The entry point for inference, `inference.py`, defines a `model_fn` function, as seen in the following code block that prints out the file. The function loads the model and sets it to use a GPU, if available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pygmentize code/inference.py"
   ]
  },
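  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the `code` directory is not at hand, the following sketch shows the general shape of a `model_fn`. It is not the contents of `code/inference.py`: the `Net` class is a hypothetical placeholder architecture, and the `model.pth` file name assumes that the training script saved a `state_dict` under that name. Only the `model_fn(model_dir)` signature comes from the SageMaker PyTorch serving convention.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    # Hypothetical placeholder network; the real architecture lives in the training code.\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc = nn.Linear(28 * 28, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.fc(torch.flatten(x, 1))\n",
    "\n",
    "\n",
    "def model_fn(model_dir):\n",
    "    # Use a GPU if one is available, otherwise fall back to the CPU.\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    model = Net()\n",
    "    # Assumption: training saved the weights as a state_dict named model.pth.\n",
    "    state = torch.load(os.path.join(model_dir, 'model.pth'), map_location=device)\n",
    "    model.load_state_dict(state)\n",
    "    return model.to(device).eval()\n",
    "```\n",
    "\n",
    "The serving container calls `model_fn` once at startup with the directory where `model.tar.gz` was extracted, so any file name used here has to match what the training job wrote."
   ]
  },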
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker.pytorch import PyTorchModel\n",
    "\n",
    "role = sagemaker.get_execution_role()\n",
    "\n",
    "# Point the model at the trained artifact and the inference script in code/\n",
    "model = PyTorchModel(model_data=model_data, source_dir='code', entry_point='inference.py',\n",
    "                     role=role, framework_version='1.6.0', py_version='py3')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deploy the model on an endpoint\n",
    "\n",
    "You create a `predictor` by using the `model.deploy` function. You can optionally change both the instance count and instance type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the model\n",
    "You can test the deployed model using samples from the test set.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the test set\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "test_set = datasets.MNIST('data', download=True, train=False, \n",
    "                          transform=transforms.Compose([\n",
    "                            transforms.ToTensor(),\n",
    "                            transforms.Normalize((0.1307,), (0.3081,))\n",
    "                            ]))\n",
    "\n",
    "\n",
    "# Randomly sample 16 images from the test set\n",
    "test_loader = DataLoader(test_set, shuffle=True, batch_size=16)\n",
    "test_images, _ = next(iter(test_loader))\n",
    "\n",
    "# inspect the images\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "def imshow(img):\n",
    "    img = img.numpy()\n",
    "    img = np.transpose(img, (1, 2, 0))\n",
    "    plt.imshow(img)\n",
    "    return\n",
    "\n",
    "# unnormalize the test images for displaying\n",
    "unnorm_images = (test_images * 0.3081) + 0.1307\n",
    "\n",
    "print(\"Sampled test images: \")\n",
    "imshow(torchvision.utils.make_grid(unnorm_images))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Send the sampled images to endpoint for inference\n",
    "outputs = predictor.predict(test_images.numpy())\n",
    "predicted = np.argmax(outputs, axis=1)\n",
    "\n",
    "print(\"Predictions: \")\n",
    "print(predicted.tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup\n",
    "\n",
    "If you don't plan to run more inference or otherwise use the endpoint, delete it so that you stop incurring charges for it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p36",
   "language": "python",
   "name": "conda_pytorch_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
